// vim: syntax=P4
/*
    ConQuest: Fine-Grained Queue Measurement in the Data Plane
    
    Copyright (C) 2020 Xiaoqi Chen, Princeton University
    xiaoqic [at] cs.princeton.edu / https://doi.org/10.1145/3359989.3365408
    
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/
{% if False %}
    #error This is a code template, not a P4 program. Please use the accompanying code generator. 
{% endif %}


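//== Parameters for ConQuest time-windowing
// number of snapshots (must be a power of two: CQ_H == 1 << LOG_CQ_H)
// CQ_H = {{CQ_H}}
// LOG_CQ_H = {{LOG_CQ_H}}

// number of rows in each snapshot's count-min sketch (CMS)
// CQ_R = {{CQ_R}}

// number of columns (counters per row) in each snapshot's CMS (must be a power of two)
// CQ_C = {{CQ_C}}
// LOG_CQ_C = {{LOG_CQ_C}}

// Time window T in nanoseconds (must be a power of two)
// CQ_T = {{CQ_T}}
// LOG_CQ_T = {{LOG_CQ_T}}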

// SKETCH_INC is the amount added to a sketch counter for each packet:
// the constant 1 when the generator sets COUNT_PACKETS, otherwise the IPv4
// total_len field widened to 32 bits (so counters accumulate bytes).
// The byte variant expands to an expression on `hdr`, so the macro is only
// usable where a `hdr` with a valid ipv4 header is in scope.
{% if COUNT_PACKETS %}
#define SKETCH_INC ((bit<32>) 1)
{% else %}//count bytes
#define SKETCH_INC ((bit<32>) hdr.ipv4.total_len)
{% endif %}


// Sanity checks for the generator parameters, evaluated by the preprocessor
// on the rendered numbers. Each size must equal 2 to the power of its LOG_
// counterpart, because the code below uses the LOG_ values as bit widths and
// as shift / slice positions.
#if ({{CQ_H}} != (1<<{{LOG_CQ_H}}))
  #error CQ_H LOG_CQ_H mismatch
#endif
#if ({{CQ_C}} != (1<<{{LOG_CQ_C}}))
  #error CQ_C LOG_CQ_C mismatch
#endif
#if ({{CQ_T}} != (1<<{{LOG_CQ_T}}))
  #error CQ_T LOG_CQ_T mismatch
#endif

// Soft check (warning only): one column is cleaned per processed packet, so
// a snapshot with CQ_C columns needs about CQ_C packets within one window of
// CQ_T ns to be fully cleaned. The bound CQ_T/6.7 is the number of packets
// that fit in CQ_T ns at the line rate quoted in the comment below.
#if ({{CQ_C}}> {{ (CQ_T/6.7) | int }} )
  #warning Snapshot might be too wide, you want at least CQ_C packets per CQ_T nanoseconds to do cyclic cleaning adequately
  //100Gbps=148.8Mpps=6.7ns/pkt
#endif


//== Preamble: constants, headers

#include <core.p4>
#include <tna.p4>

// Fixed-width aliases for common header field types.
typedef bit<48> mac_addr_t;
typedef bit<32> ipv4_addr_t;
typedef bit<16> ether_type_t;

// EtherType values.
const ether_type_t ETHERTYPE_IPV4 = 16w0x0800;
// The IEEE 802.1Q VLAN tag TPID is 0x8100 (previously mistyped as 0x0810).
// The constant is not referenced elsewhere in the visible source.
const ether_type_t ETHERTYPE_VLAN = 16w0x8100;

// IPv4 protocol numbers used by the parser and the egress hashing logic.
typedef bit<8> ip_protocol_t;
const ip_protocol_t IP_PROTOCOLS_ICMP = 1;
const ip_protocol_t IP_PROTOCOLS_TCP = 6;
const ip_protocol_t IP_PROTOCOLS_UDP = 17;


// Ethernet II header (14 bytes): destination MAC, source MAC, EtherType.
header ethernet_h {
    mac_addr_t   dst_addr;
    mac_addr_t   src_addr;
    ether_type_t ether_type;  // compared against ETHERTYPE_* in the parser
}

// IPv4 header without options (20 bytes). The former 8-bit TOS byte is split
// into a 6-bit diffserv field and a 2-bit ecn field so ECN can be matched and
// rewritten on its own.
header ipv4_h {
    // byte 0
    bit<4>        version;
    bit<4>        ihl;
    // byte 1
    bit<6>        diffserv;
    bit<2>        ecn;
    // bytes 2-7
    bit<16>       total_len;
    bit<16>       identification;
    bit<3>        flags;
    bit<13>       frag_offset;
    // bytes 8-11
    bit<8>        ttl;
    ip_protocol_t protocol;
    bit<16>       hdr_checksum;
    // bytes 12-19
    ipv4_addr_t   src_addr;
    ipv4_addr_t   dst_addr;
}

// TCP header without options (20 bytes). Only src_port and dst_port are read
// by the visible control logic (for flow hashing).
header tcp_h {
    bit<16> src_port;
    bit<16> dst_port;
    
    bit<32> seq_no;
    bit<32> ack_no;
    bit<4> data_offset;
    bit<4> res;
    bit<8> flags;
    bit<16> window;
    bit<16> checksum;
    bit<16> urgent_ptr;
}

// UDP header (8 bytes). Only src_port and dst_port are read by the visible
// control logic (for flow hashing).
header udp_h {
    bit<16> src_port;
    bit<16> dst_port;
    // NOTE(review): field name is misspelt ("lenght"); kept as-is because
    // renaming a header field changes the program's external naming.
    bit<16> hdr_lenght;
    bit<16> checksum;
}

// All headers this program can parse and emit, in wire order. For a given
// packet at most one of tcp / udp is extracted by EtherIPTCPUDPParser.
struct header_t {
    ethernet_h ethernet;
    ipv4_h ipv4;
    tcp_h tcp;
    udp_h udp;
}

// A pair of 32-bit words. Not referenced elsewhere in the visible source;
// presumably intended as the value type of a paired register -- TODO confirm.
struct paired_32bit {
    bit<32> hi;
    bit<32> lo;
}


//== Metadata definition

// Ingress user metadata: intentionally empty, ingress only does routing.
struct ig_metadata_t {
}
// Egress user metadata: all per-packet scratch state of the ConQuest logic.
// Most fields are generated per snapshot (CQ_H of them) and per sketch row
// (CQ_R of them).
struct eg_metadata_t {
    bit<8> num_snapshots_to_read;//result of division (delay/T), could be larger than H
    // Current epoch, i.e. bits [LOG_CQ_T+LOG_CQ_H-1 : LOG_CQ_T] of the egress timestamp.
    bit<{{LOG_CQ_H}}> snap_epoch;
    // Queuing delay of this packet, copied from deq_timedelta.
    bit<18> q_delay;
    
    // Per-packet random value, used as a match key in tb_per_flow_action.
    bit<8> random_bits;
    
    // Column index taken from the free-running cleaning counter.
    bit<{{LOG_CQ_C}}> cyclic_index;
    // Column index per row, derived from the flow hash.
    {% for rowID in range(CQ_R) %}
        bit<{{LOG_CQ_C}}> hashed_index_row_{{rowID}};
    {% endfor %}
    
    // For every (snapshot, row): the index actually used for the register
    // access, and the value returned by that access.
    {% for snapID in range(CQ_H) %}
        {% for rowID in range(CQ_R) %}
            bit<{{LOG_CQ_C}}> snap_{{snapID}}_row_{{rowID}}_index;
            bit<32> snap_{{snapID}}_row_{{rowID}}_read;
        {% endfor %}
    {% endfor %}    
    
    // Layer 0: per-snapshot minimum across rows.
    {% for snapID in range(CQ_H) %}
        bit<32> snap_{{snapID}}_read_min_l0;
    {% endfor %}  
    
    // Layers 1..LOG_CQ_H: pairwise partial sums of the layer below (the
    // "_min_" in the name is inherited from layer 0; these hold sums).
    {% for layer in range(LOG_CQ_H) %}
        {% set skip = 2**layer %}
        {% for j in range(0,CQ_H, skip*2) %}
            bit<32> snap_{{j}}_read_min_l{{layer+1}};
        {% endfor %}  
    {% endfor %}  
    {# Template variable naming the top-layer field (sum over all snapshots); used later in the egress control. #}
    {% set snap_read_sum =  "eg_md.snap_" ~ 0 ~ "_read_min_l" ~ LOG_CQ_H %}
}


//== Parser and deparser

// Consumes the Tofino-specific prefix of every ingress packet: the ingress
// intrinsic metadata followed by a 64-bit block that is skipped in both the
// normal and the resubmit case.
parser TofinoIngressParser(
        packet_in pkt,
        inout ig_metadata_t ig_md,
        out ingress_intrinsic_metadata_t ig_intr_md) {

    state start {
        pkt.extract(ig_intr_md);
        transition select(ig_intr_md.resubmit_flag) {
            0 : parse_port_metadata;
            1 : parse_resubmit;
        }
    }

    // Normal packet: skip the port metadata (64 bits on Tofino 1).
    state parse_port_metadata {
        pkt.advance(64);
        transition accept;
    }

    // Resubmitted packet: the resubmit data is not used, just skip 64 bits.
    state parse_resubmit {
        pkt.advance(64);
        transition accept;
    }
}

// Shared L2-L4 parser used by both ingress and egress.
// Ethernet must carry IPv4 (anything else is rejected); after IPv4 the TCP
// or UDP header is extracted when the protocol field says so, and any other
// protocol is accepted with only Ethernet + IPv4 valid.
parser EtherIPTCPUDPParser(
    packet_in pkt,
    out header_t hdr) {

    state start {
        transition parse_ethernet;
    }

    state parse_ethernet {
        pkt.extract(hdr.ethernet);
        transition select (hdr.ethernet.ether_type) {
            ETHERTYPE_IPV4 : parse_ipv4;
            default        : reject;
        }
    }

    state parse_ipv4 {
        pkt.extract(hdr.ipv4);
        transition select (hdr.ipv4.protocol) {
            IP_PROTOCOLS_TCP : parse_tcp;
            IP_PROTOCOLS_UDP : parse_udp;
            default          : accept;
        }
    }

    // Nothing is parsed beyond the L4 header, so both L4 states accept directly.
    state parse_tcp {
        pkt.extract(hdr.tcp);
        transition accept;
    }

    state parse_udp {
        pkt.extract(hdr.udp);
        transition accept;
    }
}

// Top-level ingress parser: first strips the Tofino intrinsic/port metadata
// via TofinoIngressParser, then parses Ethernet/IPv4/TCP|UDP via
// EtherIPTCPUDPParser.
parser SwitchIngressParser(
        packet_in pkt,
        out header_t hdr,
        out ig_metadata_t ig_md,
        out ingress_intrinsic_metadata_t ig_intr_md) {

    TofinoIngressParser() tofino_parser;
    EtherIPTCPUDPParser() layer4_parser;

    state start {
        // Order matters: the intrinsic metadata precedes the packet headers.
        tofino_parser.apply(pkt, ig_md, ig_intr_md);
        layer4_parser.apply(pkt, hdr);
        transition accept;
    }
}

// Ingress deparser: emits the headers of header_t in declaration order.
// No checksum update, mirroring, or resubmit is performed here.
control SwitchIngressDeparser(
        packet_out pkt,
        inout header_t hdr,
        in ig_metadata_t ig_md,
        in ingress_intrinsic_metadata_for_deparser_t ig_intr_dprsr_md) {    
    apply {    
        pkt.emit(hdr);
    }
}

// Egress parser: extracts the egress intrinsic metadata, then re-parses
// Ethernet/IPv4/TCP|UDP with the shared EtherIPTCPUDPParser.
// eg_md is not written here; the egress control fills it in.
parser SwitchEgressParser(
        packet_in pkt,
        out header_t hdr,
        out eg_metadata_t eg_md,
        out egress_intrinsic_metadata_t eg_intr_md) {
    
    EtherIPTCPUDPParser() layer4_parser;
    state start {
        pkt.extract(eg_intr_md);
        layer4_parser.apply(pkt, hdr);
        transition accept;
    }
}

// Egress deparser: emits the headers of header_t in declaration order.
// NOTE(review): the IPv4 checksum is not recomputed here although the egress
// control can rewrite hdr.ipv4.ecn (mark_ECN) -- confirm this is intended.
control SwitchEgressDeparser(
        packet_out pkt,
        inout header_t hdr,
        in eg_metadata_t eg_md,
        in egress_intrinsic_metadata_for_deparser_t eg_intr_md_for_dprsr) {
    apply {
        pkt.emit(hdr);
    }
}


//== Control logic 

// Ingress control: a minimal IPv4 forwarder. Each packet is looked up by
// exact destination address; on a miss it is sent back out of the port it
// arrived on. All ConQuest logic lives in the egress control.
control SwitchIngress(
        inout header_t hdr,
        inout ig_metadata_t ig_md,
        in ingress_intrinsic_metadata_t ig_intr_md,
        in ingress_intrinsic_metadata_from_parser_t ig_intr_prsr_md,
        inout ingress_intrinsic_metadata_for_deparser_t ig_intr_dprsr_md,
        inout ingress_intrinsic_metadata_for_tm_t ig_intr_tm_md) {

    // Do nothing.
    action nop() {
    }

    // Mark the packet to be dropped once ingress processing completes.
    action drop() {
        ig_intr_dprsr_md.drop_ctl = 0x1;
    }

    // Mark the packet for dropping and stop ingress processing immediately.
    action drop_and_exit() {
        drop();
        exit;
    }

    // Forward to the given egress port.
    action route_to_port(bit<9> port) {
        ig_intr_tm_md.ucast_egress_port = port;
    }

    // Send the packet back out of its ingress port.
    action reflect() {
        ig_intr_tm_md.ucast_egress_port = ig_intr_md.ingress_port;
    }

    // Exact-match IPv4 routing table; a miss reflects the packet.
    table tb_route_ipv4 {
        key = {
            hdr.ipv4.dst_addr : exact;
        }
        actions = {
            route_to_port;
            reflect;
            drop;
        }
        default_action = reflect();
    }

    apply {
        tb_route_ipv4.apply();
    }
}

control SwitchEgress(
        inout header_t hdr,
        inout eg_metadata_t eg_md,
        in egress_intrinsic_metadata_t eg_intr_md,
        in egress_intrinsic_metadata_from_parser_t eg_intr_md_from_prsr,
        inout egress_intrinsic_metadata_for_deparser_t ig_intr_dprs_md,
        inout egress_intrinsic_metadata_for_output_port_t eg_intr_oport_md) {
    
    // Do nothing.
    action nop() {
    }

    // Mark the packet for dropping by the egress deparser. (The parameter is
    // named ig_intr_dprs_md in this control's signature, but it is the egress
    // deparser metadata.)
    action drop() {
        ig_intr_dprs_md.drop_ctl = 0x1;
    }

    // Leave the egress control immediately; nothing after it runs.
    action skip() {
        exit;
    }

    // Set both ECN bits of the IPv4 header.
    action mark_ECN() {
        hdr.ipv4.ecn = 0x3;
    }
    
    
    // Limit to only process traffic for a single egress port.
    // This prototype is single-port version; multi-port version needs memory partitioning.
    // Marker action: matching ports simply fall through to the ConQuest
    // logic that follows the gatekeeper in apply.
    action run_conquest(){
        nop();
    }
    // Gate on the egress port. The table holds a single entry (size = 1); any
    // port without an entry hits the default skip(), which exits the control.
    // No const entries are given, so the entry comes from the control plane.
    table tb_gatekeeper {
        key = {
            eg_intr_md.egress_port: exact;
        }
        actions = {
            skip;
            run_conquest;
        }
        size = 1;
        default_action = skip();
    }
    
    
    //== Start: What time is it? How long is the queue?
    // Derives the three time-related values everything else keys on:
    //   q_delay               - this packet's queuing delay (18 bits)
    //   num_snapshots_to_read - floor(q_delay / T), truncated to 8 bits
    //   snap_epoch            - floor(dequeue time / T) mod H
    action prep_epochs(){
        eg_md.q_delay = eg_intr_md.deq_timedelta;
        eg_md.num_snapshots_to_read = (bit<8>) (eg_intr_md.deq_timedelta >> {{LOG_CQ_T}});
        // Note: in P4_14 the delay in queuing meta is 32 bits. In P4_16, to recover 32-bit queuing delay, you need to manually bridge a longer timestamp from ingress.

        // Dividing by T drops the low LOG_CQ_T bits of the 48-bit timestamp;
        // taking the result mod H keeps the next LOG_CQ_H bits.
        eg_md.snap_epoch = eg_intr_md_from_prsr.global_tstamp[{{LOG_CQ_T}}+{{LOG_CQ_H}}-1:{{LOG_CQ_T}}];
    }
    
    // Zeroes every per-snapshot, per-row read result, so that snapshots whose
    // round-robin table takes no register action contribute 0 later on.
    action prep_reads(){
        {% for rowID in range(CQ_R) %}
            {% for snapID in range(CQ_H) %}
                eg_md.snap_{{snapID}}_row_{{rowID}}_read = 0;
            {% endfor %}
        {% endfor %}
    }
    
    // 8-bit random source; the value is only consumed as a range key of
    // tb_per_flow_action.
    Random< bit<8> >() rng;
    action prep_random(){
        eg_md.random_bits = rng.get();
    }
    
    //== Prepare register access index options
    // Single-cell 32-bit counter that advances by one for every packet that
    // executes calc_cyclic_index.
    Register<bit<32>,_>(1) reg_cleaning_index;
    RegisterAction<bit<32>, _, bit<32>>(reg_cleaning_index) reg_cleaning_index_rw = {
        void apply(inout bit<32> val, out bit<32> rv) {
            // Return the value before the increment.
            rv = val;
            val = val + 1;
        }
    };
    // The cast keeps only the low LOG_CQ_C bits of the counter, so the index
    // sweeps through all CQ_C columns and wraps around.
    action calc_cyclic_index(){
        eg_md.cyclic_index = (bit<{{LOG_CQ_C}}>) reg_cleaning_index_rw.execute(0);
    }
    
    // One CRC32-based hash unit per sketch row and per traffic class
    // (TCP / UDP / other), each producing a LOG_CQ_C-bit column index.
    // Separate instances are declared for each calling action.
    {% for rowID in range(CQ_R) %}
        Hash<bit<{{LOG_CQ_C}}>>(HashAlgorithm_t.CRC32) hash_{{rowID}}_TCP;  
        Hash<bit<{{LOG_CQ_C}}>>(HashAlgorithm_t.CRC32) hash_{{rowID}}_UDP;  
        Hash<bit<{{LOG_CQ_C}}>>(HashAlgorithm_t.CRC32) hash_{{rowID}}_Other;   
    {% endfor %}
       
    // Computes the per-row column index for a TCP packet from the 4-tuple
    // (src/dst address, src/dst port). Each field is preceded by a value
    // rendered from get_seed(), a template function not defined in this file.
    // NOTE(review): presumably get_seed() yields a different constant on each
    // call so that rows hash independently -- confirm in the code generator.
    action calc_hashed_index_TCP(){
       {% for rowID in range(CQ_R) %}
           eg_md.hashed_index_row_{{rowID}} = hash_{{rowID}}_TCP.get({
               {{get_seed()}}, hdr.ipv4.src_addr,
               {{get_seed()}}, hdr.ipv4.dst_addr,
               {{get_seed()}}, hdr.tcp.src_port,
               {{get_seed()}}, hdr.tcp.dst_port
           });
       {% endfor %}
    }
    // Computes the per-row column index for a UDP packet from the 4-tuple
    // (src/dst address, src/dst port), interleaved with values rendered from
    // get_seed() (template function defined outside this file).
    action calc_hashed_index_UDP(){
       {% for rowID in range(CQ_R) %}
           eg_md.hashed_index_row_{{rowID}} = hash_{{rowID}}_UDP.get({
               {{get_seed()}}, hdr.ipv4.src_addr,
               {{get_seed()}}, hdr.ipv4.dst_addr,
               {{get_seed()}}, hdr.udp.src_port,
               {{get_seed()}}, hdr.udp.dst_port
           });
       {% endfor %}
    }
    // Computes the per-row column index for a packet that is neither TCP nor
    // UDP: the flow key is (src address, dst address, IP protocol),
    // interleaved with values rendered from get_seed() (template function
    // defined outside this file).
    action calc_hashed_index_Other(){
       {% for rowID in range(CQ_R) %}
           eg_md.hashed_index_row_{{rowID}} = hash_{{rowID}}_Other.get({
               {{get_seed()}}, hdr.ipv4.src_addr,
               {{get_seed()}}, hdr.ipv4.dst_addr,
               {{get_seed()}}, hdr.ipv4.protocol
           });
       {% endfor %}
    }
    
    
    //== Deciding on using hashed-based or cyclic-based index
    // For every snapshot, choose which column index its register accesses use
    // for this packet:
    //   - epoch == snapshot ID: the snapshot is being cleaned, so use the
    //     cyclic index that sweeps over all columns;
    //   - any other epoch: the snapshot is incremented or read, so use the
    //     per-row index derived from the flow hash.
    // Fix: the two action bodies were swapped relative to their names and to
    // the table wiring, which made cleaning use the flow hash and inc/read
    // use the sweep counter.
    {% for snapID in range(CQ_H) %}
        // Flow-hash based index, one per row (used for increment / read).
        action snap_{{snapID}}_select_index_hash(){
            {% for rowID in range(CQ_R) %}
                eg_md.snap_{{snapID}}_row_{{rowID}}_index=eg_md.hashed_index_row_{{rowID}};
            {% endfor %}
        }
        // Cyclic index, same for all rows (used for cleaning).
        action snap_{{snapID}}_select_index_cyclic(){
            {% for rowID in range(CQ_R) %}
                eg_md.snap_{{snapID}}_row_{{rowID}}_index=eg_md.cyclic_index;
            {% endfor %}
        }
        table tb_snap_{{snapID}}_select_index {
            key = {
                eg_md.snap_epoch: exact;
            }
            actions = {
                snap_{{snapID}}_select_index_hash;
                snap_{{snapID}}_select_index_cyclic;
            }
            size = 2;
            default_action = snap_{{snapID}}_select_index_hash();
            const entries = {
               {{snapID}} : snap_{{snapID}}_select_index_cyclic();
            }
        }
    {% endfor %}
    
    
    //== Prepare snapshot register access actions 
    // For every (snapshot, row) pair this generates:
    //   - a register array of CQ_C 32-bit counters;
    //   - three register actions on it (read, increment by SKETCH_INC, clear),
    //     each wrapped in a table action using the pair's selected index;
    //   - a round-robin table keyed on (snap_epoch, num_snapshots_to_read)
    //     that picks exactly one of those actions, or nop.
    // Round-robin schedule for snapshot s, as encoded in the const entries:
    //   epoch == s            -> clear
    //   epoch == (s+1) mod H  -> increment
    //   epoch == (s+i) mod H, for i in 2..H-1, and num_snapshots_to_read >= i-1 -> read
    //   anything else         -> nop (the read field keeps the 0 set by prep_reads)
    {% for snapID in range(CQ_H) %}
        {% for rowID in range(CQ_R) %}
            Register<bit<32>,_>({{CQ_C}}) snap_{{snapID}}_row_{{rowID}};
            // Read: return the counter unchanged.
            RegisterAction<bit<32>, _, bit<32>> (snap_{{snapID}}_row_{{rowID}}) snap_{{snapID}}_row_{{rowID}}_read = {
                    void apply(inout bit<32> val, out bit<32> rv) {
                        rv = val;
                    }
                };
            action regexec_snap_{{snapID}}_row_{{rowID}}_read(){
                eg_md.snap_{{snapID}}_row_{{rowID}}_read=snap_{{snapID}}_row_{{rowID}}_read.execute(eg_md.snap_{{snapID}}_row_{{rowID}}_index);
            }
            // Increment: add SKETCH_INC and return the updated counter.
            RegisterAction<bit<32>, _, bit<32>> (snap_{{snapID}}_row_{{rowID}}) snap_{{snapID}}_row_{{rowID}}_inc = {
                    void apply(inout bit<32> val, out bit<32> rv) {
                        val = val + SKETCH_INC;
                        rv = val;
                    }
                };
            // Note: the post-increment value is also stored in the read field.
            action regexec_snap_{{snapID}}_row_{{rowID}}_inc(){
                eg_md.snap_{{snapID}}_row_{{rowID}}_read=snap_{{snapID}}_row_{{rowID}}_inc.execute(eg_md.snap_{{snapID}}_row_{{rowID}}_index);
            }
            // Clear: zero the counter; the returned 0 is discarded by the action.
            RegisterAction<bit<32>, _, bit<32>> (snap_{{snapID}}_row_{{rowID}}) snap_{{snapID}}_row_{{rowID}}_clr = {
                    void apply(inout bit<32> val, out bit<32> rv) {
                        val = 0;
                        rv = 0;
                    }
                };
            action regexec_snap_{{snapID}}_row_{{rowID}}_clr(){
                snap_{{snapID}}_row_{{rowID}}_clr.execute(eg_md.snap_{{snapID}}_row_{{rowID}}_index);
            }
            table tb_snap_{{snapID}}_row_{{rowID}}_rr {
                key = {
                    eg_md.snap_epoch: exact;
                    eg_md.num_snapshots_to_read: range;
                }
                actions = {
                    regexec_snap_{{snapID}}_row_{{rowID}}_read;
                    regexec_snap_{{snapID}}_row_{{rowID}}_inc;
                    regexec_snap_{{snapID}}_row_{{rowID}}_clr;
                    nop;
                }
                // Only CQ_H const entries are generated below; the declared size is larger.
                size = {{ (CQ_H*CQ_H+1) }};
                default_action = nop();
                //round-robin logic
                const entries = {
                    ({{  snapID  }}, 0..255) : regexec_snap_{{snapID}}_row_{{rowID}}_clr;
                    ({{ (snapID+1)%CQ_H}}, 0..255) : regexec_snap_{{snapID}}_row_{{rowID}}_inc;
                    //read only when qlen/T is large enough (otherwise leave 0)
                    {% for i in range(2,CQ_H) %}
                        ({{ (snapID+i)%CQ_H}}, {{i-1}}..255) : regexec_snap_{{snapID}}_row_{{rowID}}_read;
                    {% endfor %}
                }
            }
        {% endfor %}
    {% endfor %}  
    
    //== Folding sums, which can't be written inline 
    // Pairwise-sum tree over the per-snapshot values. At layer L the action
    // calc_sum_j_lL adds the layer-L values of snapshots j and j + 2^L and
    // stores the result as snapshot j's layer-(L+1) value. After LOG_CQ_H
    // layers, snap_0_read_min_l<LOG_CQ_H> holds the sum over all snapshots.
    {% for layer in range(LOG_CQ_H) %}
        {% set skip = 2**layer %}
        {% for j in range(0,CQ_H, skip*2) %}
            action calc_sum_{{j}}_l{{layer}}(){
                eg_md.snap_{{j}}_read_min_l{{layer+1}} = 
                eg_md.snap_{{j}}_read_min_l{{layer}} + eg_md.snap_{{j+skip}}_read_min_l{{layer}};
            }
        {% endfor %}  
    {% endfor %}  
    
    //== Finally, actions based on flow size in the queue
    // Final policy table: decides per packet whether to do nothing, drop, or
    // mark ECN, based on the summed sketch reading, the queuing delay, a
    // random value, and the packet's current ECN bits.
    // No const entries are defined, so without control-plane entries every
    // packet hits the default nop().
    table tb_per_flow_action {
        key = {
            {{snap_read_sum}}[26:10]: range; //scale down: bits 26..10 of the sum (a 17-bit slice, low 10 bits dropped)
            eg_md.q_delay: range;
            eg_md.random_bits: range;
            hdr.ipv4.ecn : exact;
        }
        actions = {
            nop;
            drop;
            mark_ECN;
        }
        default_action = nop();
        // const entries = {  }
    }
    
    // Per-packet egress pipeline. The statement order reflects data
    // dependencies: epoch/indices -> index selection -> register accesses ->
    // min across rows -> sum across snapshots -> policy table.
    apply {
        // Ports without a run_conquest entry hit skip(), which exits here.
        tb_gatekeeper.apply();
        
        // Startup
        prep_epochs();
        prep_reads();
        prep_random();
        
        // Index for sketch cleaning and read/write
        calc_cyclic_index();
        // The flow key depends on the L4 protocol (ports only exist for TCP/UDP).
        if(hdr.ipv4.protocol==IP_PROTOCOLS_TCP){
            calc_hashed_index_TCP();
        }else if(hdr.ipv4.protocol==IP_PROTOCOLS_UDP){
            calc_hashed_index_UDP();
        }else{
            calc_hashed_index_Other();
        }
        
        // Select index for snapshots. Cyclic for cleaning, hashed for read/inc
        {% for snapID in range(CQ_H) %}
            tb_snap_{{snapID}}_select_index.apply();
        {% endfor %}   
        
        // Run the snapshots! Round-robin clean, inc, read
        {% for snapID in range(CQ_H) %}
            {% for rowID in range(CQ_R) %}
                 tb_snap_{{snapID}}_row_{{rowID}}_rr.apply();
            {% endfor %}
        {% endfor %}   
        
        // Calc min across rows (as in count-"min" sketch)
        {% for snapID in range(CQ_H) %}
            {% if CQ_R == 1 %} 
                eg_md.snap_{{snapID}}_read_min_l0=eg_md.snap_{{snapID}}_row_0_read; //single row: its value is the minimum
            {% else %}
                eg_md.snap_{{snapID}}_read_min_l0=min(eg_md.snap_{{snapID}}_row_0_read,eg_md.snap_{{snapID}}_row_1_read);
            {% endif %}
            {# Rows 2 and up are folded into the running minimum one at a time. #}
            {% for rowID in range(2,CQ_R) %}
                 eg_md.snap_{{snapID}}_read_min_l0=min(eg_md.snap_{{snapID}}_read_min_l0,eg_md.snap_{{snapID}}_row_{{rowID}}_read);
            {% endfor %}
        {% endfor %}   
        
        // Sum all reads together, using log(CQ_H) layers.
        {% for layer in range(LOG_CQ_H) %}
            {% set skip = 2**layer %}
            {% for j in range(0,CQ_H, skip*2) %}
                calc_sum_{{j}}_l{{layer}}();
            {% endfor %}  
        {% endfor %}  
        // bit<32> snap_read_sum={{snap_read_sum}};
        
        // With flow size in queue, can check for bursty flow and add AQM.
        tb_per_flow_action.apply();
    }
}


// Assemble the six programmable blocks into one pipeline and instantiate the
// switch with it.
Pipeline(SwitchIngressParser(),
         SwitchIngress(),
         SwitchIngressDeparser(),
         SwitchEgressParser(),
         SwitchEgress(),
         SwitchEgressDeparser()
         ) pipe;

Switch(pipe) main;